{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-0BQWhvAP2jb"
      },
      "source": [
        "\n",
        "\u003ca href=\"https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bqZYp90PIa1t"
      },
      "source": [
        "# Overview\n",
        "\n",
        "This is the third Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [Introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) and the [Training Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb), or have a basic understanding of the T5X models, checkpoints, partitioner, trainer, and `InteractiveModel`.\n",
        "\n",
        "In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series, we dove into how the InteractiveModel restores models from checkpoints and runs training, while also getting an introduction to the T5X trainer. In this Colab, we will focus on how the `InteractiveModel` does decoding to generate predictions and scores for a given input. It should be noted that the code snippets below exactly replicate the InteractiveModel `__init__()` and `infer_with_preprocessors()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here in order to demonstrate how various components of the T5X codebase work together to run inference on a model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nZJbWZcfkyxI"
      },
      "source": [
        "# Set-Up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8my9yhSRi6GG"
      },
      "source": [
        "Note: If you are a using public colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jIGSIHzD7YPO"
      },
      "outputs": [],
      "source": [
        "from collections.abc import Sequence\n",
        "import enum\n",
        "import functools\n",
        "import inspect\n",
        "import itertools\n",
        "import logging\n",
        "import os\n",
        "import re\n",
        "from typing import Any, Callable, Iterator, Optional, Tuple, Union\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "from jax.experimental import multihost_utils\n",
        "import numpy as np\n",
        "import seqio\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import t5.data\n",
        "import t5.data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yNtayQIxjEBd"
      },
      "outputs": [],
      "source": [
        "import clu.data\n",
        "from t5x.examples.t5 import network\n",
        "import t5x\n",
        "from t5x import models\n",
        "from t5x import partitioning\n",
        "from t5x import trainer as trainer_lib\n",
        "from t5x import utils\n",
        "from t5x.infer import _extract_tokens_and_aux_values\n",
        "from t5x.infer import _Inferences\n",
        "from t5x.interactive_model import InteractiveModel\n",
        "from t5x.interactive_model import get_batches_from_seqio\n",
        "from t5x.interactive_model import get_dataset_from_natural_text_examples\n",
        "from t5x.interactive_model import get_gin_config_from_interactive_model\n",
        "from t5x.interactive_model import T5XScriptType\n",
        "from t5x.interactive_model import InferenceType"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Lb-Z1fkF5a"
      },
      "source": [
        "Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` runs inference.\n",
        "\n",
        "If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ne8U8qoWkX_r"
      },
      "outputs": [],
      "source": [
        "# Define a model. The configuration below corresponds to the T5 1.1 Small model.\n",
        "t5_config = network.T5Config(\n",
        "    vocab_size=32128,\n",
        "    dtype='bfloat16',\n",
        "    emb_dim=512,\n",
        "    num_heads=6,\n",
        "    num_encoder_layers=8,\n",
        "    num_decoder_layers=8,\n",
        "    head_dim=64,\n",
        "    mlp_dim=1024,\n",
        "    mlp_activations=('gelu', 'linear'),\n",
        "    dropout_rate=0.0,\n",
        "    logits_via_embedding=False)\n",
        "module = network.Transformer(config=t5_config)\n",
        "model = t5x.models.EncoderDecoderModel(\n",
        "    module=module,\n",
        "    input_vocabulary=t5.data.get_default_vocabulary(),\n",
        "    output_vocabulary=t5.data.get_default_vocabulary(),\n",
        "    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))\n",
        "# Define checkpoint arguments.\n",
        "checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'\n",
        "dtype='bfloat16'\n",
        "restore_mode='specific'\n",
        "# Define a partitioner.\n",
        "partitioner=partitioning.PjitPartitioner(num_partitions=2)\n",
        "# Define additional, miscellaneous constructor arguments.\n",
        "batch_size=8\n",
        "task_feature_lengths = {'inputs': 38, 'targets': 18}\n",
        "output_dir='/tmp/output_dir'\n",
        "input_shapes = {\n",
        "    'encoder_input_tokens': np.array([8, 38]),\n",
        "    'decoder_target_tokens': np.array([8, 18]),\n",
        "    'decoder_input_tokens': np.array([8, 18]),\n",
        "    'decoder_loss_weights': np.array([8, 18])\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYwdg-fFTU8Q"
      },
      "source": [
        "In addition, we will run all code that is performed when we initialize the InteractiveModel. If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the [second Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YmGTJBAcTpMR"
      },
      "outputs": [],
      "source": [
        "# 1.) Configure the Output Directory\n",
        "output_dir = re.sub(r\"(?\u003c!gs:)([\\/]{2,})\", \"/\", output_dir)\n",
        "if not os.path.exists(output_dir):\n",
        "  os.mkdir(output_dir)\n",
        "\n",
        "# 2.) Initialize RNGs\n",
        "init_random_seed = 42\n",
        "random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))\n",
        "utils.set_hardware_rng_ops()\n",
        "rng = random.PRNGKey(random_seed)\n",
        "init_rng, trainer_rng = random.split(rng, 2)\n",
        "\n",
        "# 3.) Validate the Partitioner\n",
        "if partitioner._model_parallel_submesh:\n",
        "  num_partitions = np.prod(partitioner._model_parallel_submesh)\n",
        "else:\n",
        "  num_partitions = partitioner._num_partitions\n",
        "if jax.device_count() % num_partitions != 0:\n",
        "  raise ValueError(\n",
        "    \"The number of devices available must be a multiple of the number of\",\n",
        "    f\" partitions. There are {jax.device_count()} devices available, but\",\n",
        "    f\" the number of partitions is set to {num_partitions}. Please\",\n",
        "    \" provide a different number of partitions.\")\n",
        "\n",
        "# 4.) Create a Checkpoint Manager\n",
        "# a.) Define CheckpointCfg wrappers.\n",
        "save_checkpoint_cfg = utils.SaveCheckpointConfig(\n",
        "        dtype=dtype,\n",
        "        keep=5, # The number of checkpoints to keep in the output_dir.\n",
        "        save_dataset=False)\n",
        "restore_checkpoint_cfg = utils.RestoreCheckpointConfig(\n",
        "        dtype=dtype,\n",
        "        mode=restore_mode,\n",
        "        path=checkpoint_path)\n",
        "\n",
        "# b.) Define a train state initializer, which will help us get information about the\n",
        "# TrainState shape.\n",
        "train_state_initializer = utils.TrainStateInitializer(\n",
        "        optimizer_def=model.optimizer_def,\n",
        "        init_fn=model.get_initial_variables,\n",
        "        input_shapes=input_shapes,\n",
        "        input_types=None,\n",
        "        partitioner=partitioner)\n",
        "\n",
        "# c.) Define the checkpoint manager.\n",
        "checkpoint_manager = utils.LegacyCheckpointManager(\n",
        "        save_cfg=save_checkpoint_cfg,\n",
        "        restore_cfg=restore_checkpoint_cfg,\n",
        "        train_state_shape=train_state_initializer.global_train_state_shape,\n",
        "        partitioner=partitioner,\n",
        "        ds_iter=None,\n",
        "        model_dir=output_dir)\n",
        "\n",
        "### 5.) Restore the Model from a Checkpoint, or Initialize from Scratch ###\n",
        "def get_state(rng):\n",
        "  return train_state_initializer.from_scratch(rng).state_dict()\n",
        "\n",
        "# a.) Try to restore a model from a checkpoint.\n",
        "train_state = checkpoint_manager.restore(\n",
        "  [restore_checkpoint_cfg.path],\n",
        "  restore_checkpoint_cfg,\n",
        "  utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)\n",
        ")\n",
        "\n",
        "# b.) If no checkpoint to restore, init from scratch.\n",
        "if train_state is None:\n",
        "  train_state = train_state_initializer.from_scratch(init_rng)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ib9aOi2xaCKQ"
      },
      "source": [
        "# Inference Deep Dive"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ANqpfv0lAVqL"
      },
      "source": [
        "**Defining a Batch of Examples to Run Inference On**\\\n",
        "Let's start by defining a batch of examples that we will get predictions and scores for.\n",
        "\n",
        "These examples should be a list of inputs; we don't need any targets, because we will eventually generate predictions. For this Colab, we'll use a set of natural text questions (and we will generate the answers)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yhhR0yDcAn7w"
      },
      "outputs": [],
      "source": [
        "examples = [\n",
        "    b'nq question: who has been appointed as the new chairman of sebi',\n",
        "    b'nq question: who wrote the book lion the witch and the wardrobe',\n",
        "    b'nq question: how many planes did japan lose at pearl harbor',\n",
        "    b'nq question: who does the voice of mcgruff the dog',\n",
        "    b'nq question: who sings the wheels in the sky keep on turning',\n",
        "    b'nq question: who voices regina in glitter force doki doki',\n",
        "    b'nq question: when did the us become allies with britain',\n",
        "    b'nq question: who won the rugby 7 in las vegas'\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WYV1LMS5taE9"
      },
      "source": [
        "We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below. `targets` will be empty for our examples, because we do not have any targets to provide at inference time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nj5I7YMotb9U"
      },
      "outputs": [],
      "source": [
        "output_features = {\n",
        "        \"inputs\":\n",
        "            seqio.Feature(\n",
        "                vocabulary=model.input_vocabulary, add_eos=True),\n",
        "        \"targets\":\n",
        "            seqio.Feature(\n",
        "                vocabulary=model.output_vocabulary, add_eos=True)\n",
        "    }\n",
        "features = dict(sorted(output_features.items()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IZuleutswi1H"
      },
      "source": [
        "Finally, we'll have to determine whether we want to get predictions or scores for this batch. For this example, we'll get predictions, which we'll denote by setting an inference mode variable to `PREDICT_WITH_AUX`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PbMPt5eBw-a4"
      },
      "outputs": [],
      "source": [
        "mode = InferenceType.PREDICT_WITH_AUX\n",
        "# Try replacing this variable with `InferenceType.SCORE` to produce scores."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mixLzcBkQOT_"
      },
      "source": [
        "Now, let's break down what the interactive model does to run inference.\n",
        "\n",
        "The `InteractiveModel` `infer_with_preprocessors()` method only performs three actions:\n",
        "\n",
        "\n",
        "1.   Convert the natural text examples into a tf.Dataset.\n",
        "2.   Define an `infer_fn`; depending on whether we want predictions or scores, this function will be equivalent to `model.predict_batch` or `model.score_batch`.\n",
        "3.   Extract inferences and return them.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ug0zJx2kQk6g"
      },
      "source": [
        "**Prepare the dataset** \\\n",
        "\n",
        "Preparing the data for inference is fairly straightforward; in fact, this is nearly the same data preparation that happens for training.\n",
        "\n",
        "First, we convert the natural text examples into a tf.Dataset and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that can do exactly that. For this example, the only preprocessing we will do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro.\n",
        "\n",
        "Finally, we  convert all features using the model's feature converter, pad all batches of data, and define an iterator over our data (this allows us to run inference on multiple batches of examples)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "chPomDFxQ6r3"
      },
      "outputs": [],
      "source": [
        "dataset = get_dataset_from_natural_text_examples(\n",
        "    examples,\n",
        "    preprocessors=[\n",
        "        seqio.preprocessors.tokenize,\n",
        "        seqio.preprocessors.append_eos\n",
        "    ],\n",
        "    task_feature_lengths=task_feature_lengths,\n",
        "    features=features)\n",
        "feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)\n",
        "model_dataset = feature_converter(\n",
        "    dataset, task_feature_lengths=task_feature_lengths)\n",
        "# Zip task and model features.\n",
        "infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))\n",
        "# Create batches and index them.\n",
        "infer_dataset = infer_dataset.padded_batch(\n",
        "    batch_size, drop_remainder=False).enumerate()\n",
        "infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(\n",
        "    infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "toq2uz7dAfdL"
      },
      "source": [
        "**Define Infer Function** \\\n",
        "\n",
        "We'll define a helper function that runs inference on a single batch, making it easy to loop over this helper and run inference for multiple batches. This `infer_fn` can either get predictions or scores, depending on the mode we've previously set.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qv8RXHeXK6mk"
      },
      "outputs": [],
      "source": [
        "if mode == InferenceType.PREDICT_WITH_AUX:\n",
        "  infer_step = model.predict_batch_with_aux\n",
        "elif mode == InferenceType.SCORE:\n",
        "  infer_step = model.score_batch\n",
        "else:\n",
        "  raise ValueError(\"Mode must be `predict_with_aux`, or `score`,\"\n",
        "                  f\" but instead was {mode}.\")\n",
        "infer_fn = functools.partial(\n",
        "  utils.get_infer_fn(\n",
        "    infer_step=infer_step,\n",
        "    batch_size=batch_size,\n",
        "    train_state_axes=train_state_initializer.train_state_axes,\n",
        "    partitioner=partitioner),\n",
        "  train_state=train_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c-BoczLLzaGb"
      },
      "source": [
        "**Extract Inferences** \\\n",
        "\n",
        "Finally, we will extract inferences for each batch of examples provided. For each batch, we:\n",
        "\n",
        "1.  Unzip the dataset to get both the task dataset and the model dataset (the model dataset is what you get when you've passed the task dataset through the model feature converter).\n",
        "2.  Get an RNG for the batch.\n",
        "3.  Extract predictions and auxiliary values using the T5X helper, `_extract_tokens_and_aux_values`.\n",
        "4.  Decode the predictions using our vocabulary.\n",
        "5.  Accumulate predictions, aux values, and inputs across all of our batches.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GDAyEz9mzghq"
      },
      "outputs": [],
      "source": [
        "# Main Loop over \"batches\".\n",
        "all_inferences = []\n",
        "all_aux_values = {}\n",
        "for chunk, chunk_batch in infer_dataset_iter:\n",
        "  # Load the dataset for the next chunk. We can't use `infer_dataset_iter`\n",
        "  # directly since `infer_fn` needs to know the exact size of each chunk,\n",
        "  # which may be smaller for the final one.\n",
        "  chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)\n",
        "  chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "  # Unzip chunk dataset in to pretokenized and model datasets.\n",
        "  task_dataset = chunk_dataset.map(\n",
        "      lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "  model_dataset = chunk_dataset.map(\n",
        "      lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "  # Get a chunk-specific RNG key.\n",
        "  chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)\n",
        "\n",
        "  inferences = _extract_tokens_and_aux_values(\n",
        "      infer_fn(model_dataset.enumerate(), rng=chunk_rng))\n",
        "\n",
        "  predictions, aux_values = inferences\n",
        "  accumulated_inferences = []\n",
        "  for idx, inputs in task_dataset.enumerate().as_numpy_iterator():\n",
        "    prediction = predictions[idx]\n",
        "    # Decode predictions if applicable.\n",
        "    if mode == InferenceType.PREDICT_WITH_AUX:\n",
        "      prediction = features[\"targets\"].vocabulary.decode_tf(\n",
        "          tf.constant(prediction)).numpy()\n",
        "    accumulated_inferences.append((inputs, prediction))\n",
        "  all_inferences += accumulated_inferences\n",
        "  # Accumulate aux values over batches.\n",
        "  if not all_aux_values:\n",
        "    all_aux_values = aux_values\n",
        "  else:\n",
        "    for key, values in aux_values.items():\n",
        "      all_aux_values[key] += values\n",
        "print(all_inferences)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIzrCUyWQcDZ"
      },
      "source": [
        "We can parse these predictions into a more readable format using the code below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f9_BPXG_QgJs"
      },
      "outputs": [],
      "source": [
        "for input, prediction in all_inferences:\n",
        "  print(f\"Input: {input['inputs_pretokenized']}\")\n",
        "  print(f\"Prediction: {prediction}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AB0U_kfRNIyR"
      },
      "source": [
        "The code snippets above exactly replicate the `InteractiveModel` `infer_with_preprocessors()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.infer_with_preprocessors(mode, examples, preprocessors=[seqio.preprocessors.tokenize, seqio.preprocessors.append_eos])`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-QR5LnmN4ikp"
      },
      "source": [
        "# Advanced Topics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CLstCKpP8Ge7"
      },
      "source": [
        "## T5X Inference Binaries and Other Advanced Features\n",
        "\n",
        "T5X offers inference binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, inference on TF Example files, prediction services, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.\n",
        "\n",
        "If you are familiar with Gin and interested in using the T5X inference binaries, we have provided a helper function, get_gin_config_from_interactive_model, which will take an InteractiveModel instance and generate the gin config that you can use to run the T5X inference binaries; this gin config will exactly reproduce the InteractiveModel inference functionality we've described above. We've provided an example below.\n",
        "\n",
        "Importantly, the InteractiveModel takes in a model, partitioner, and data, so we cannot generate Gin configs for these components. You can pass Gin config strings for the model and partitioner components to the helper function, as demonstrated below. Additionally, you can pass a SeqIO task containing your data to the helper function. See the section below if you are unfamiliar with SeqIO."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nqa49ZpjnRN1"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/google-research/google-research.git google_research"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rhgUZ0w6yQsE"
      },
      "outputs": [],
      "source": [
        "# Define an InteractiveModel instance, based on the `small` T5X EncoderDecoder model.\n",
        "t5_config = network.T5Config(\n",
        "    vocab_size=32128,\n",
        "    dtype='bfloat16',\n",
        "    emb_dim=512,\n",
        "    num_heads=6,\n",
        "    num_encoder_layers=8,\n",
        "    num_decoder_layers=8,\n",
        "    head_dim=64,\n",
        "    mlp_dim=1024,\n",
        "    mlp_activations=('gelu', 'linear'),\n",
        "    dropout_rate=0.0,\n",
        "    logits_via_embedding=False)\n",
        "module = network.Transformer(config=t5_config)\n",
        "model = t5x.models.EncoderDecoderModel(\n",
        "    module=module,\n",
        "    input_vocabulary=t5.data.get_default_vocabulary(),\n",
        "    output_vocabulary=t5.data.get_default_vocabulary(),\n",
        "    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),\n",
        "    decode_fn=functools.partial(\n",
        "        t5x.decoding.temperature_sample, temperature=1.0, topk=40))\n",
        "interactive_model = InteractiveModel(\n",
        "    batch_size=8,\n",
        "    task_feature_lengths={'inputs': 38, 'targets': 18},\n",
        "    output_dir='/tmp/output_dir',\n",
        "    partitioner=partitioning.PjitPartitioner(\n",
        "      num_partitions=1,\n",
        "      model_parallel_submesh=None,\n",
        "      logical_axis_rules=partitioning.standard_logical_axis_rules()),\n",
        "    model=model,\n",
        "    dtype='bfloat16',\n",
        "    restore_mode='specific',\n",
        "    checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000',\n",
        "    input_shapes={\n",
        "      'encoder_input_tokens': np.array([8, 38]),\n",
        "      'decoder_target_tokens': np.array([8, 18]),\n",
        "      'decoder_input_tokens': np.array([8, 18]),\n",
        "      'decoder_loss_weights': np.array([8, 18])\n",
        "    },\n",
        "    input_types=None)\n",
        "\n",
        "# Define Gin Config strings for the model, partitioner, and any imports.\n",
        "imports_str = \"\"\"from t5x import models\n",
        "from t5x import partitioning\n",
        "import t5.data.mixtures\n",
        "include 't5x/examples/t5/t5_1_1/small.gin'\n",
        "\n",
        "# Register necessary SeqIO Tasks/Mixtures.\n",
        "import google_research.t5_closed_book_qa.t5_cbqa.tasks\"\"\"\n",
        "partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'\n",
        "\n",
        "gin_config_str = get_gin_config_from_interactive_model(\n",
        "  interactive_model=interactive_model,\n",
        "  script_type=T5XScriptType.INFERENCE,\n",
        "  task_name='closed_book_qa',\n",
        "  partitioner_config_str=partitioner_config,\n",
        "  model_config_str='',  # No config needed, since we just import the model.\n",
        "  imports_str=imports_str,\n",
        ")\n",
        "print(gin_config_str)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uGd1DxDT3gB7"
      },
      "source": [
        "Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your inference experiment locally by running the following on commandline:\n",
        "\n",
        "\n",
        "```\n",
        "INFER_OUTPUT_DIR=\"/tmp/inference-model/\"\n",
        "python -m t5x.infer_unfragmented \\\n",
        "  --gin_file=${GIN_FILE_PATH} \\\n",
        "  --gin.INFER_OUTPUT_DIR=\\\"${INFER_OUTPUT_DIR}\\\" \\\n",
        "  --alsologtostderr\n",
        "```\n",
        "For more details on inference using the T5X inference binaries, please see the [Inference](https://github.com/google-research/t5x/blob/main/docs/usage.md/infer-seqio) tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wi29fMdv4mSr"
      },
      "source": [
        "## SeqIO\n",
        "\n",
        "If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.\n",
        "\n",
        "We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.\n",
        "\n",
        "\u003ca href=\"https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
        "\n",
        "If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bM0nRIEFwyj_"
      },
      "outputs": [],
      "source": [
        "import google_research.t5_closed_book_qa.t5_cbqa.tasks\n",
        "batches = get_batches_from_seqio(\n",
        "        task_or_mixture_name='natural_questions_open',\n",
        "        split='validation',\n",
        "        batch_size=8,\n",
        "        num_batches=2,\n",
        "        seed=42)\n",
        "print(f\"Batches: {batches}\")\n",
        "# Train the interactive model on the provided batches.\n",
        "original_step = interactive_model.step\n",
        "_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)\n",
        "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Elt08160w03X"
      },
      "source": [
        "The `get_batches_from_seqio` bridge can take several constructor arguments:\n",
        "\n",
        "\n",
        "1.   `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).\n",
        "2.   `split`: the split of the Task/Mixture to read data from.\n",
        "3.   `batch_size`: how many examples should appear in each batch.\n",
        "4.   `num_batches`: the total number of batches to return.\n",
        "5.   `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \\\n",
        "The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \\\n",
        "However, some SeqIO tasks do not use tokenization (ex: span corruption). You can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.\n",
        "6.   `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.\n",
        "7.   `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `SeqIO.get_dataset` function. If you would like to set any of these arguments, you can set them using this `kwargs` parameter.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fjKBCX39w0Xl"
      },
      "outputs": [],
      "source": [
        "import t5.data.tasks\n",
        "batches = get_batches_from_seqio(\n",
        "    task_or_mixture_name='c4_v220_span_corruption',\n",
        "    split='validation',\n",
        "    batch_size=8,\n",
        "    num_batches=1,\n",
        "    get_pretokenized_examples=False,\n",
        "    sequence_length=interactive_model._task_feature_lengths,\n",
        "    seed=42)\n",
        "batch = batches[0]  # We expect only a single batch.\n",
        "original_step = interactive_model.step\n",
        "interactive_model.train_step_with_preprocessors(\n",
        "        examples=batch, preprocessors=[])\n",
        "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Welcome to T5X: Inference Deep Dive",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1ZyKqEf1xpxGUxX9JeQ0VN94MQlWGFrHf",
          "timestamp": 1676340117712
        },
        {
          "file_id": "1hQO9MD6psZtTeqZyXPJIoUV0uzTa2qPg",
          "timestamp": 1662951508591
        },
        {
          "file_id": "1Akpc6pKlJB5rn5YYYFC9lw2OMk6oBzlQ",
          "timestamp": 1662754223629
        },
        {
          "file_id": "1rA8bgO2bJRoebAuS96Ji0RUhnawgBY4i",
          "timestamp": 1650477076639
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
